SHAP analysis¶

In [ ]:
import joblib
from pathlib import Path
from omegaconf import OmegaConf

import pandas as pd
import shap
from sklearn.metrics import classification_report
In [ ]:
# Directory of the experiment whose artefacts are analysed in this notebook.
EXPERIMENT_ROOT = "../../experiments/rf_features_only"
# Results file stored inside the experiment directory.
RESULTS = Path(EXPERIMENT_ROOT).joinpath("results.yaml")

# Parse the results file with OmegaConf.
results = OmegaConf.load(RESULTS)
In [ ]:
# When True, scaler.pkl is loaded from the experiment directory and applied to VEC_COLS.
HAS_SCALER = False
DATA_ROOT = "../../data/prepared/"
MODEL = Path(EXPERIMENT_ROOT) / "model.pkl"
SCALER = Path(EXPERIMENT_ROOT) / "scaler.pkl"
VAL_DATA = Path(DATA_ROOT) / "val_features.pkl"
TEST_DATA = Path(DATA_ROOT) / "test_features.pkl"
# Column labels passed through the scaler (only used when HAS_SCALER is True).
VEC_COLS = list(range(768))
# Target column, and every column that must be removed before calling the model.
LABEL_COL = "retweet_label"
NON_FEATURE_COLS = [LABEL_COL, "id_str"]


def load_split(path):
    """Read one pickled data split and separate the features from the labels.

    Parameters
    ----------
    path : str or pathlib.Path
        Location of a pickled DataFrame containing the ``retweet_label`` and
        ``id_str`` columns next to the feature columns.

    Returns
    -------
    tuple
        ``(features, labels)``: ``features`` is a new DataFrame without the
        columns in ``NON_FEATURE_COLS``; ``labels`` is the ``retweet_label``
        Series with the same index.
    """
    frame = pd.read_pickle(path)
    # Non-mutating drop: the frame that was read is left intact.
    return frame.drop(columns=NON_FEATURE_COLS), frame[LABEL_COL]


# NOTE: pd.read_pickle / joblib.load can execute arbitrary code while
# unpickling, so only point these paths at trusted files.
val_df, val_df_labels = load_split(VAL_DATA)
test_df, test_df_labels = load_split(TEST_DATA)

model = joblib.load(MODEL)

if HAS_SCALER:
    scaler = joblib.load(SCALER)
    # The scaler is only applied (transform), never re-fitted, on both splits.
    transformed_val = scaler.transform(val_df[VEC_COLS].values)
    val_df[VEC_COLS] = transformed_val
    transformed_test = scaler.transform(test_df[VEC_COLS].values)
    test_df[VEC_COLS] = transformed_test
In [ ]:
# Preview the first five rows of the validation features.
val_df.head(n=5)
Out[ ]:
entities.urls entities.media user_in_net has_covid_keyword tweets_keywords_3_in_degree tweets_keywords_3_out_degree tweets_keywords_3_in_strength tweets_keywords_3_out_strength tweets_keywords_3_eigenvector_in tweets_keywords_3_eigenvector_out ... users_reply_clustering user.followers_isna users_mention_isna following_users_isna users_reply_isna log1p_num_hashtags log1p_followers_count log1p_friends_count log1p_statuses_count log1p_num_mentioned
173185 0 1 0 0 -0.647823 -0.510160 -0.967028 0.944173 -0.180504 -0.054619 ... -0.797732 0 0.595178 1 1.481133 1.098612 0.762279 1.036245 -0.241881 0.0
173186 0 1 0 0 -0.647823 -0.651175 -0.967028 -0.954585 -0.180504 -0.054619 ... -0.797732 0 0.595178 1 1.481133 0.693147 0.762279 1.036245 -0.241881 0.0
173187 0 0 0 0 -0.647823 -0.510160 -0.967028 0.944173 -0.180504 -0.054619 ... -0.797732 0 0.595178 1 1.481133 0.000000 0.762279 1.036245 -0.241881 0.0
173188 0 1 0 0 -0.647823 -0.651175 -0.967028 -0.954585 -0.180504 -0.054619 ... -0.797732 0 0.595178 1 1.481133 0.000000 0.762279 1.036245 -0.241881 0.0
173189 0 1 0 0 -0.647823 -0.651175 -0.967028 -0.954585 -0.180504 -0.054619 ... -0.797732 0 0.595178 1 1.481133 0.000000 0.762279 1.036245 -0.241881 0.0

5 rows × 49 columns

In [ ]:
# Preview the first five rows of the test features.
test_df.head(n=5)
Out[ ]:
entities.urls entities.media user_in_net has_covid_keyword tweets_keywords_3_in_degree tweets_keywords_3_out_degree tweets_keywords_3_in_strength tweets_keywords_3_out_strength tweets_keywords_3_eigenvector_in tweets_keywords_3_eigenvector_out ... users_reply_clustering user.followers_isna users_mention_isna following_users_isna users_reply_isna log1p_num_hashtags log1p_followers_count log1p_friends_count log1p_statuses_count log1p_num_mentioned
194715 1 1 1 1 1.908895 2.059462 0.938154 0.944173 -0.180504 -0.054602 ... 0.791296 0 0.595178 0 -0.675159 0.693147 1.203002 0.781949 0.956741 0.0
194716 1 1 1 0 -0.106151 -0.375587 1.083082 0.944173 -0.180504 -0.054618 ... 0.791296 0 0.595178 0 -0.675159 0.000000 1.203002 0.781949 0.956741 0.0
194717 1 1 1 0 -0.310304 -0.247161 0.925793 0.944173 -0.180504 -0.054619 ... 0.791296 0 0.595178 0 -0.675159 0.000000 1.203002 0.781949 0.956741 0.0
194718 1 1 1 1 -0.206248 1.744481 0.925793 0.944173 -0.180504 -0.054610 ... 0.791296 0 0.595178 0 -0.675159 0.693147 1.203002 0.781949 0.956741 0.0
194719 1 0 1 0 -0.647823 -0.651175 -0.967028 -0.954585 -0.180504 -0.054619 ... 0.791296 0 0.595178 0 -0.675159 0.693147 1.203002 0.781949 0.956741 0.0

5 rows × 49 columns

Validation and test results¶

In [ ]:
# Predict on the validation split and print per-class precision / recall / F1.
val_predictions = model.predict(val_df.values)
val_out = classification_report(
    y_true=val_df_labels.values,
    y_pred=val_predictions,
    digits=3,
    output_dict=False,
)
print(val_out)
              precision    recall  f1-score   support

           0      0.726     0.674     0.699     10954
           1      0.634     0.690     0.661      8989

    accuracy                          0.681     19943
   macro avg      0.680     0.682     0.680     19943
weighted avg      0.685     0.681     0.682     19943

[Parallel(n_jobs=24)]: Using backend ThreadingBackend with 24 concurrent workers.
[Parallel(n_jobs=24)]: Done   2 tasks      | elapsed:    0.0s
[Parallel(n_jobs=24)]: Done 152 tasks      | elapsed:    0.1s
[Parallel(n_jobs=24)]: Done 200 out of 200 | elapsed:    0.1s finished
In [ ]:
# Predict on the test split and print per-class precision / recall / F1.
test_predictions = model.predict(test_df.values)
test_out = classification_report(
    y_true=test_df_labels.values,
    y_pred=test_predictions,
    digits=3,
    output_dict=False,
)
print(test_out)
              precision    recall  f1-score   support

           0      0.716     0.640     0.676     10639
           1      0.633     0.710     0.669      9305

    accuracy                          0.673     19944
   macro avg      0.675     0.675     0.673     19944
weighted avg      0.677     0.673     0.673     19944

[Parallel(n_jobs=24)]: Using backend ThreadingBackend with 24 concurrent workers.
[Parallel(n_jobs=24)]: Done   2 tasks      | elapsed:    0.0s
[Parallel(n_jobs=24)]: Done 152 tasks      | elapsed:    0.1s
[Parallel(n_jobs=24)]: Done 200 out of 200 | elapsed:    0.1s finished

SHAP Explainer preparation and test data sampling¶

In [ ]:
# Build the SHAP explainer for the loaded model; it is reused by every plot below.
explainer = shap.Explainer(model)
In [ ]:
# Number of rows in the full test split (before sampling).
test_df.shape[0]
Out[ ]:
19944
In [ ]:
# Computing SHAP values for every test row is slow, so explain a random
# subset instead (usually 100-1000 rows is enough). The fixed random_state
# keeps the subset identical across runs.
test_df_sample = test_df.sample(
    random_state=42,
    frac=0.05,
)
In [ ]:
# Pick the labels of the sampled rows by index label, so they stay aligned with test_df_sample.
test_df_labels_sample = test_df_labels.loc[test_df_sample.index]
In [ ]:
# Display the sampled test features (pandas truncates the rendering).
test_df_sample
Out[ ]:
entities.urls entities.media user_in_net has_covid_keyword tweets_keywords_3_in_degree tweets_keywords_3_out_degree tweets_keywords_3_in_strength tweets_keywords_3_out_strength tweets_keywords_3_eigenvector_in tweets_keywords_3_eigenvector_out ... users_reply_clustering user.followers_isna users_mention_isna following_users_isna users_reply_isna log1p_num_hashtags log1p_followers_count log1p_friends_count log1p_statuses_count log1p_num_mentioned
196995 0 1 1 0 -0.647823 -0.651175 -0.967028 -0.954585 -0.180504 -0.054619 ... 1.341809 0 0.595178 0 -0.675159 0.000000 0.001409 0.576489 0.647097 0.0
208764 0 1 0 0 0.639640 1.776070 1.004117 1.055509 -0.180504 -0.054619 ... -0.797732 0 0.595178 1 1.481133 0.000000 -0.203655 -1.429558 -0.643354 0.0
198537 1 0 1 1 0.171867 1.041619 0.925793 0.960493 -0.180504 -0.054619 ... -0.797732 0 -1.680170 0 -0.675159 0.000000 -1.295609 -0.905254 -0.341030 1.0
199457 0 0 1 0 -0.647823 -0.651175 -0.967028 -0.954585 -0.180504 -0.054619 ... 1.103448 0 -1.680170 0 -0.675159 0.000000 -0.436851 0.158218 -1.101759 1.0
204573 0 0 1 0 2.409977 2.188824 1.005758 1.028141 -0.180504 -0.054619 ... -0.797732 0 0.595178 0 -0.675159 2.079442 -0.249862 0.215947 0.440486 0.0
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
213637 0 1 1 0 1.427123 2.108751 1.037142 1.027553 -0.180504 0.734861 ... -0.797732 0 0.595178 0 -0.675159 1.791759 0.129038 0.559467 -0.564597 0.0
214221 1 0 1 0 -0.206248 1.888641 1.174367 1.064132 -0.180504 -0.054109 ... 1.209000 0 0.595178 0 -0.675159 1.609438 0.684298 1.069249 0.523685 0.0
212261 0 0 1 0 0.257579 2.372401 0.925793 0.990173 -0.180504 0.226447 ... 0.029750 0 0.595178 0 -0.675159 0.000000 -0.090805 0.502197 -0.012185 0.0
213929 0 0 1 0 -0.418476 -0.510160 0.925793 0.944173 -0.180504 -0.054619 ... 1.581774 0 0.595178 0 -0.675159 0.693147 0.130952 -0.651068 0.757100 0.0
198449 1 0 0 0 -0.418476 -0.651175 0.925793 -0.954585 -0.180504 -0.054619 ... -0.797732 0 0.595178 1 1.481133 0.000000 -1.564675 -0.719433 -0.964353 0.0

997 rows × 49 columns

In [ ]:
# Display the labels selected for the sampled rows.
test_df_labels_sample
Out[ ]:
196995    1
208764    1
198537    0
199457    0
204573    0
         ..
213637    1
214221    1
212261    1
213929    1
198449    1
Name: retweet_label, Length: 997, dtype: int64
In [ ]:
# Class probabilities for the sampled rows (one column per class); used below
# to choose which individual predictions to explain.
test_df_conf_sample = model.predict_proba(test_df_sample.to_numpy())
test_df_conf_sample
[Parallel(n_jobs=24)]: Using backend ThreadingBackend with 24 concurrent workers.
[Parallel(n_jobs=24)]: Done   2 tasks      | elapsed:    0.0s
[Parallel(n_jobs=24)]: Done 152 tasks      | elapsed:    0.0s
[Parallel(n_jobs=24)]: Done 200 out of 200 | elapsed:    0.1s finished
Out[ ]:
array([[0.5306387 , 0.4693613 ],
       [0.39770972, 0.60229028],
       [0.77102733, 0.22897267],
       ...,
       [0.40260657, 0.59739343],
       [0.58668253, 0.41331747],
       [0.70577764, 0.29422236]])
In [ ]:
# Compute SHAP values for every sampled test row.
shap_values = explainer(test_df_sample)
In [ ]:
# Shape of the base values: (samples, classes).
shap_values.base_values.shape
Out[ ]:
(997, 2)
In [ ]:
# Shape of the SHAP values: (samples, features, classes); the last axis is indexed per class in the plots below.
shap_values.values.shape
Out[ ]:
(997, 49, 2)
In [ ]:
# Shape of the feature data stored with the SHAP values: (samples, features).
shap_values.data.shape
Out[ ]:
(997, 49)

SHAP values across the test data¶

In [ ]:
# Waterfall explanation for class 0 of a confident, correct prediction:
# row 2 of test_df_conf_sample has a class-0 probability of 0.771.
idx = 2
exp = shap.Explanation(
    values=shap_values.values[:, :, 0],
    base_values=shap_values.base_values[:, 0],
    data=shap_values.data,
    display_data=test_df_sample,
)
shap.plots.waterfall(exp[idx], max_display=30)
In [ ]:
# Sanity check: the class-1 explanation of the same example should give a symmetric chart.
idx = 2  # the same example
exp = shap.Explanation(
    values=shap_values.values[:, :, 1],
    base_values=shap_values.base_values[:, 1],
    data=shap_values.data,
    display_data=test_df_sample,
)
shap.plots.waterfall(exp[idx], max_display=30)
In [ ]:
# Class-1 explanation of the same example, drawn as a horizontal force plot.
idx = 2
shap.initjs()  # load the JavaScript needed by the interactive force plot
shap.force_plot(
    explainer.expected_value[1],
    shap_values.values[idx, :, 1],
    test_df_sample.iloc[idx],
)
Out[ ]:
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.
In [ ]:
# Visualize the class-1 explanations of all sampled predictions at once.
# The resulting plot is interactive: different variables can be explored in the notebook.
shap.initjs()
shap.force_plot(
    explainer.expected_value[1],
    shap_values.values[:, :, 1],
    test_df_sample,
)
Out[ ]:
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.

Feature importance¶

In [ ]:
# Dot summary plot of each feature's influence on class 1.
shap.summary_plot(
    shap_values.values[:, :, 1],
    features=test_df_sample,
    plot_type="dot",
)

Feature importance and correlation¶

In [ ]:
# Hierarchically cluster the features (single linkage) from the sampled features and labels.
# The result is passed as `clustering` to the bar plot in the next cell.
# Features with no/low signal are reported in the cell output and get a distance of 1.
clust = shap.utils.hclust(test_df_sample, test_df_labels_sample, linkage="single")
`early_stopping_rounds` in `fit` method is deprecated for better compatibility with scikit-learn, use `early_stopping_rounds` in constructor or`set_params` instead.
No/low signal found from feature 2 (this is typically caused by constant or near-constant features)! Cluster distances can't be computed for it (so setting all distances to 1).
No/low signal found from feature 3 (this is typically caused by constant or near-constant features)! Cluster distances can't be computed for it (so setting all distances to 1).
 84%|████████▎ | 41/49 [00:11<00:03,  2.35it/s]No/low signal found from feature 40 (this is typically caused by constant or near-constant features)! Cluster distances can't be computed for it (so setting all distances to 1).
No/low signal found from feature 41 (this is typically caused by constant or near-constant features)! Cluster distances can't be computed for it (so setting all distances to 1).
No/low signal found from feature 42 (this is typically caused by constant or near-constant features)! Cluster distances can't be computed for it (so setting all distances to 1).
100%|██████████| 49/49 [00:13<00:00,  3.03it/s]No/low signal found from feature 48 (this is typically caused by constant or near-constant features)! Cluster distances can't be computed for it (so setting all distances to 1).
50it [00:13,  1.04s/it]                        
In [ ]:
# Class-1 explanation with explicit feature names, drawn as a bar plot that
# uses the feature clustering computed in the previous cell.
exp = shap.Explanation(
    values=shap_values.values[:, :, 1],
    base_values=shap_values.base_values[:, 1],
    data=shap_values.data,
    display_data=test_df_sample,
    feature_names=test_df_sample.columns,
)
# NOTE(review): the sample has 49 feature columns but max_display is 48 - confirm this is intended.
shap.plots.bar(
    exp,
    max_display=48,
    clustering=clust,
    clustering_cutoff=1,
)

SHAP dependence plots¶

"SHAP dependence plots show the effect of a single feature across the whole dataset. They plot a feature's value vs. the SHAP value of that feature across many samples. SHAP dependence plots are similar to partial dependence plots, but account for the interaction effects present in the features, and are only defined in regions of the input space supported by data. The vertical dispersion of SHAP values at a single feature value is driven by interaction effects, and another feature is chosen for coloring to highlight possible interactions."

https://shap.readthedocs.io/en/latest/example_notebooks/tabular_examples/tree_based_models/Census%20income%20classification%20with%20LightGBM.html

In [ ]:
# One dependence plot per feature, using the class-1 SHAP values.
for name in test_df_sample.columns:
    shap.dependence_plot(
        name,
        shap_values.values[:, :, 1],
        test_df_sample,
        display_features=test_df_sample,
    )